"""
Given a pretrained contexteval model, extract the Calypso BiLM contextualizer
and compare each layer with a pre-trained Calypso BiLM contextualizer.
"""
import argparse
import logging
import os
import sys

from allennlp.models import load_archive
import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
import contexteval  # noqa:F401

logger = logging.getLogger(__name__)


def _average_l1_shift(finetuned_state_dict, original_state_dict, substrings):
    """Return the average per-element L1 shift for a subset of parameters.

    Compares each tensor in ``finetuned_state_dict`` whose key contains any
    of ``substrings`` against the tensor of the same key in
    ``original_state_dict``.

    Parameters
    ----------
    finetuned_state_dict : dict of str -> torch.Tensor
        State dict of the fine-tuned model.
    original_state_dict : dict of str -> torch.Tensor
        State dict of the pre-trained model; assumed to contain every key
        that matches ``substrings`` in the fine-tuned dict.
    substrings : iterable of str
        Key fragments identifying the layer of interest.

    Returns
    -------
    float
        Total L1 distance divided by the total number of elements compared.

    Raises
    ------
    ValueError
        If no key matches, instead of a confusing ZeroDivisionError.
    """
    total_shift = 0.0
    num_params = 0
    for key, parameter in finetuned_state_dict.items():
        if any(substring in key for substring in substrings):
            num_params += parameter.numel()
            total_shift += torch.abs(
                parameter - original_state_dict[key]).sum().item()
    if num_params == 0:
        raise ValueError(
            "No parameters matched substrings: {}".format(tuple(substrings)))
    return total_shift / num_params


def main():
    """Log the average L1 parameter shift between a fine-tuned Calypso BiLM
    (loaded from ``args.archive_path``) and the pre-trained weights
    (loaded from ``args.weight_file``), for the character encoder and both
    LSTM layers.
    """
    # Load the fine-tuned model from the AllenNLP archive and pull out the
    # Calypso BiLM contextualizer's weights.
    archive = load_archive(args.archive_path)
    finetuned_calypso_state_dict = archive.model._language_model.state_dict()

    # Load the original pre-trained calypso weights (a serialized state dict).
    original_calypso_state_dict = torch.load(args.weight_file)

    # (human-readable label, key fragments selecting that layer's parameters)
    layers = [
        ("token embedder", ("_character_encoder",)),
        ("LSTM Layer 0", ("backward_layer_0", "forward_layer_0")),
        ("LSTM Layer 1", ("backward_layer_1", "forward_layer_1")),
    ]
    for label, substrings in layers:
        average_shift = _average_l1_shift(
            finetuned_calypso_state_dict, original_calypso_state_dict,
            substrings)
        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.info("Average Shift (L1 distance) in %s: %s",
                    label, average_shift)


if __name__ == "__main__":
    # Timestamped log lines so the three shift reports are easy to correlate.
    logging.basicConfig(format="%(asctime)s - %(levelname)s "
                        "- %(name)s - %(message)s",
                        level=logging.INFO)
    parser = argparse.ArgumentParser(
        # Fixed: the comparison target is a pre-trained Calypso BiLM, not an
        # ELMo model (the old text contradicted the module docstring).
        description=("Given a path to a model with a fine-tuned Calypso BiLM, "
                     "report the parameter shift for each layer as compared to a "
                     "pre-trained Calypso BiLM."),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Both paths are mandatory: required=True makes argparse fail fast with a
    # clear usage error instead of main() crashing on a None path later.
    parser.add_argument("--archive-path", type=str, required=True,
                        help="Path to a model.tar.gz file generated by AllenNLP.")
    parser.add_argument('--weight-file', type=str, required=True,
                        help=('The path to the calypso weight file containing a '
                              'serialized state dict.'))
    # main() reads this module-level name; keep the assignment before the call.
    args = parser.parse_args()
    main()
